// Copyright (c) 2023 Huawei Device Co., Ltd.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! `Resolver` trait and `DefaultDnsResolver` implementation.

use std::collections::HashMap;
use std::future::Future;
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::{Arc, Mutex, OnceLock};
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use std::vec::IntoIter;

use crate::runtime::JoinHandle;

/// Default `DnsManager::max_entries_len`: once the cache holds more entries
/// than this, `DnsManager::clean_expired_entries` drops the expired ones.
const DEFAULT_MAX_LEN: usize = 30000;
/// Default time-to-live (60 seconds); used for `DnsManager::ttl` and for the
/// expiration time of `DnsResult::default()`.
const DEFAULT_TTL: Duration = Duration::from_secs(60);

/// `SocketAddr` resolved by `Resolver`: a boxed, thread-safe iterator over the
/// resolved socket addresses.
pub type Addrs = Box<dyn Iterator<Item = SocketAddr> + Sync + Send>;
/// Possible errors that this resolver may generate when attempting to
/// resolve, as a boxed, thread-safe `std::error::Error`.
pub type StdError = Box<dyn std::error::Error + Send + Sync>;
/// Futures generated by this resolver when attempting to resolve an address.
///
/// The future yields either the resolved `Addrs` or a `StdError`; `'a` bounds
/// any data the future borrows.
pub type SocketFuture<'a> =
    Pin<Box<dyn Future<Output = Result<Addrs, StdError>> + Sync + Send + 'a>>;

/// Asynchronous DNS resolution interface used by
/// `async_impl::connector::HttpConnector`.
///
/// Implementors turn an authority string into a future that yields the socket
/// addresses to connect to.
pub trait Resolver: Send + Sync + 'static {
    /// Resolves `authority` into a future of `SocketAddr`s.
    ///
    /// The returned future's lifetime is tied to `&self`, so it may borrow
    /// from the resolver but not from `authority`.
    /// NOTE(review): the expected `authority` format (e.g. `host:port`) is
    /// defined by the caller — confirm.
    fn resolve(&self, authority: &str) -> SocketFuture<'_>;
}

/// `SocketAddr` resolved by `DefaultDnsResolver`.
///
/// Wraps an owning iterator over the resolved addresses and yields them in the
/// order of the underlying vector.
pub struct ResolvedAddrs {
    iter: IntoIter<SocketAddr>,
}

impl ResolvedAddrs {
    pub(super) fn new(iter: IntoIter<SocketAddr>) -> Self {
        Self { iter }
    }

    // The first ip in the dns record is the preferred addrs type.
    pub(super) fn split_preferred_addrs(self) -> (Vec<SocketAddr>, Vec<SocketAddr>) {
        // get preferred address family type.
        let is_ipv6 = self
            .iter
            .as_slice()
            .first()
            .map(SocketAddr::is_ipv6)
            .unwrap_or(false);
        self.iter
            .partition::<Vec<_>, _>(|addr| addr.is_ipv6() == is_ipv6)
    }
}

/// Yields the resolved addresses by delegating to the inner `vec::IntoIter`.
impl Iterator for ResolvedAddrs {
    type Item = SocketAddr;

    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next()
    }

    /// Forwards the exact bounds of the inner `vec::IntoIter`. Without this,
    /// consumers (including those behind the boxed `Addrs` iterator) see the
    /// default `(0, None)` and cannot preallocate when collecting.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}

/// Futures generated by `DefaultDnsResolver`.
///
/// Wraps a `JoinHandle` whose output is the resolution result. Polling it
/// converts that result into the `Resolver` output type
/// (`Result<Addrs, StdError>`).
pub struct DefaultDnsFuture {
    inner: JoinHandle<Result<ResolvedAddrs, io::Error>>,
}

impl DefaultDnsFuture {
    pub(crate) fn new(handle: JoinHandle<Result<ResolvedAddrs, io::Error>>) -> Self {
        DefaultDnsFuture { inner: handle }
    }
}

impl Future for DefaultDnsFuture {
    type Output = Result<Addrs, StdError>;

    /// Polls the wrapped join handle and, once it is ready, boxes its result
    /// into the `Resolver` output type.
    ///
    /// An `io::Error` from the resolution is passed through as-is. An error
    /// from the join handle itself is wrapped in an `io::Error` of kind
    /// `Interrupted`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let joined = match Pin::new(&mut self.inner).poll(cx) {
            Poll::Ready(joined) => joined,
            Poll::Pending => return Poll::Pending,
        };
        let output = match joined {
            Ok(Ok(addrs)) => Ok(Box::new(addrs) as Addrs),
            Ok(Err(io_err)) => Err(Box::new(io_err) as StdError),
            Err(join_err) => {
                let wrapped = io::Error::new(io::ErrorKind::Interrupted, join_err);
                Err(Box::new(wrapped) as StdError)
            }
        };
        Poll::Ready(output)
    }
}

/// Shared cache of DNS results keyed by authority.
pub(crate) struct DnsManager {
    /// Cache storing authority and DNS results
    pub(crate) map: Arc<Mutex<HashMap<String, DnsResult>>>,
    /// Cache size above which `clean_expired_entries` drops expired entries
    /// (defaults to `DEFAULT_MAX_LEN`).
    max_entries_len: usize,
    /// Time-to-live for the DNS cache (defaults to `DEFAULT_TTL`).
    /// NOTE(review): not read in this file; presumably callers use it to
    /// compute `DnsResult` expiration times — confirm.
    pub(crate) ttl: Duration,
}

impl Default for DnsManager {
    // Default constructor for `DnsManager`, with a default TTL of 60
    // seconds.
    fn default() -> Self {
        DnsManager {
            map: Default::default(),
            max_entries_len: DEFAULT_MAX_LEN,
            ttl: DEFAULT_TTL, // Default TTL set to 60 seconds
        }
    }
}

impl DnsManager {
    /// Returns the process-wide DNS manager.
    ///
    /// The manager is built with `DnsManager::default()` on first use; every
    /// call hands back a clone of the same `Arc`.
    pub(crate) fn global_dns_manager() -> Arc<Mutex<DnsManager>> {
        static GLOBAL_DNS_MANAGER: OnceLock<Arc<Mutex<DnsManager>>> = OnceLock::new();
        let shared =
            GLOBAL_DNS_MANAGER.get_or_init(|| Arc::new(Mutex::new(DnsManager::default())));
        Arc::clone(shared)
    }

    /// Removes expired entries from the cache, but only when the cache holds
    /// more than `max_entries_len` entries; at or below that size it does
    /// nothing.
    ///
    /// Valid entries are always kept, so the cache can remain above
    /// `max_entries_len` when none of its entries have expired.
    ///
    /// # Panics
    ///
    /// Panics if the cache mutex is poisoned.
    pub(crate) fn clean_expired_entries(&self) {
        let mut cache = self.map.lock().unwrap();
        if cache.len() <= self.max_entries_len {
            return;
        }
        cache.retain(|_authority, entry| entry.is_valid());
    }
}

/// A cached DNS resolution: the resolved addresses and when they expire.
#[derive(Clone)]
pub(crate) struct DnsResult {
    /// List of resolved addresses for the authority
    pub(crate) addr: Vec<SocketAddr>,
    /// Expiration time for the cache entry; the entry is valid only while the
    /// current instant is strictly earlier than this.
    expiration_time: Instant,
}

impl DnsResult {
    /// Creates a new DNS result with the given addresses and expiration time
    pub(crate) fn new(addr: Vec<SocketAddr>, expiration_time: Instant) -> Self {
        DnsResult {
            addr,
            expiration_time,
        }
    }

    /// Checks if the DNS result is still valid
    pub(crate) fn is_valid(&self) -> bool {
        self.expiration_time > Instant::now()
    }
}

impl Default for DnsResult {
    /// Returns an entry with no addresses that expires `DEFAULT_TTL`
    /// (60 seconds) from now.
    fn default() -> Self {
        let expiration_time = Instant::now() + DEFAULT_TTL;
        Self {
            addr: Vec::new(),
            expiration_time,
        }
    }
}

#[cfg(test)]
mod ut_resover_test {
    use super::*;

    /// UT test case for `DnsManager::default`
    ///
    /// # Brief
    /// 1. Creates a `DnsManager` instance via `Default`.
    /// 2. Verifies the default `max_entries_len` is 30000 and the default
    ///    `ttl` is 60 seconds.
    /// 3. Inserts a DNS result into the cache and verifies it can be found.
    #[test]
    fn ut_dns_manager_new() {
        let manager = DnsManager::default();
        assert_eq!(manager.max_entries_len, 30000);
        assert_eq!(manager.ttl, Duration::from_secs(60));
        let mut map = manager.map.lock().unwrap();
        map.insert(
            "example.com".to_string(),
            DnsResult::new(vec![SocketAddr::from(([0, 0, 0, 1], 1))], Instant::now()),
        );
        assert!(map.contains_key("example.com"));
    }

    /// UT test case for `DnsManager::clean_expired_entries`
    ///
    /// # Brief
    /// 1. Creates a `DnsManager` instance and sets `max_entries_len` to 1.
    /// 2. Adds two DNS results to the cache: one valid and one expired.
    /// 3. Calls `clean_expired_entries` to remove expired entries.
    /// 4. Verifies the expired entry is removed from the cache.
    #[test]
    fn ut_dns_manager_clean_cache() {
        let manager = DnsManager {
            max_entries_len: 1,
            ..Default::default()
        };
        let mut map = manager.map.lock().unwrap();
        map.insert(
            "example1.com".to_string(),
            DnsResult::new(
                vec![SocketAddr::from(([0, 0, 0, 1], 1))],
                Instant::now() + Duration::from_secs(60),
            ),
        );
        map.insert(
            "example2.com".to_string(),
            DnsResult::new(
                vec![SocketAddr::from(([0, 0, 0, 2], 2))],
                Instant::now() - Duration::from_secs(60),
            ),
        );
        drop(map);
        manager.clean_expired_entries();
        assert!(manager.map.lock().unwrap().contains_key("example1.com"));
        assert!(!manager.map.lock().unwrap().contains_key("example2.com"));
    }

    /// UT test case for `DnsManager::clean_expired_entries` within the limit
    ///
    /// # Brief
    /// 1. Creates a `DnsManager` instance and sets `max_entries_len` to 1.
    /// 2. Adds a single expired DNS result to the cache.
    /// 3. Calls `clean_expired_entries`.
    /// 4. Verifies the expired entry is kept, because cleanup only happens once
    ///    the cache holds more than `max_entries_len` entries.
    #[test]
    fn ut_dns_manager_clean_cache_within_limit() {
        let manager = DnsManager {
            max_entries_len: 1,
            ..Default::default()
        };
        manager.map.lock().unwrap().insert(
            "example.com".to_string(),
            DnsResult::new(vec![SocketAddr::from(([0, 0, 0, 1], 1))], Instant::now()),
        );
        manager.clean_expired_entries();
        assert!(manager.map.lock().unwrap().contains_key("example.com"));
    }

    /// UT test case for `DnsResult::is_valid` and `DnsResult::default`
    ///
    /// # Brief
    /// 1. Creates a `DnsResult` expiring 60 seconds from now and verifies it is
    ///    valid.
    /// 2. Creates a `DnsResult` expiring now and verifies it is not valid.
    /// 3. Creates a default `DnsResult` and verifies it is valid and has no
    ///    addresses.
    #[test]
    fn ut_dns_result_is_valid() {
        let fresh = DnsResult::new(vec![], Instant::now() + Duration::from_secs(60));
        assert!(fresh.is_valid());
        let expired = DnsResult::new(vec![], Instant::now());
        assert!(!expired.is_valid());
        let empty = DnsResult::default();
        assert!(empty.is_valid());
        assert!(empty.addr.is_empty());
    }

    /// UT test case for `ResolvedAddrs::split_preferred_addrs`
    ///
    /// # Brief
    /// 1. Creates a `ResolvedAddrs` whose first address is IPv4, mixed with an
    ///    IPv6 address.
    /// 2. Verifies the IPv4 addresses form the preferred group, in order, and
    ///    the IPv6 address is in the other group.
    /// 3. Verifies an empty `ResolvedAddrs` splits into two empty groups.
    #[test]
    fn ut_resolved_addrs_split_preferred_addrs() {
        let v4a = SocketAddr::from(([127, 0, 0, 1], 80));
        let v6 = SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 1], 80));
        let v4b = SocketAddr::from(([127, 0, 0, 2], 80));
        let addrs = ResolvedAddrs::new(vec![v4a, v6, v4b].into_iter());
        let (preferred, others) = addrs.split_preferred_addrs();
        assert_eq!(preferred, vec![v4a, v4b]);
        assert_eq!(others, vec![v6]);

        let empty = ResolvedAddrs::new(Vec::<SocketAddr>::new().into_iter());
        let (preferred, others) = empty.split_preferred_addrs();
        assert!(preferred.is_empty());
        assert!(others.is_empty());
    }
}
